import numpy as np
import matplotlib.pyplot as plt

from sklearn import *


def plot_decision_boundary(model, axis):
    """Shade the plane by the class ``model`` predicts at each grid point.

    Parameters
    ----------
    model : fitted classifier
        Any object whose ``predict`` accepts an ``(n_points, 2)`` array.
    axis : sequence of 4 numbers
        ``[x0_min, x0_max, x1_min, x1_max]`` window to evaluate and draw.
        The window is sampled at 100 points per unit along each direction.
    """
    from matplotlib.colors import ListedColormap

    # 100 samples per unit of axis length, but never fewer than 2 so that a
    # very small window still gives contourf a drawable (>= 2x2) grid.
    n0 = max(int((axis[1] - axis[0]) * 100), 2)
    n1 = max(int((axis[3] - axis[2]) * 100), 2)
    x0, x1 = np.meshgrid(
        np.linspace(axis[0], axis[1], n0),
        np.linspace(axis[2], axis[3], n1),
    )

    # One row per grid point; columns are the (x0, x1) coordinates.
    X_new = np.c_[x0.ravel(), x1.ravel()]
    zz = model.predict(X_new).reshape(x0.shape)

    custom_cmap = ListedColormap(['#EF9A9A', '#FFF59D', '#90CAF9'])
    # `linewidth` is not a contourf parameter (filled contours have no
    # lines), so it is not passed.
    plt.contourf(x0, x1, zz, cmap=custom_cmap)


# Import explicitly instead of relying on `from sklearn import *` to
# happen to expose the `datasets` submodule.
from sklearn import datasets

iris = datasets.load_iris()

X = iris.data
y = iris.target

# Keep only classes 0 and 1 (a binary problem) and only the first two
# feature columns, so both the data and the boundary can be drawn in 2-D.
binary_mask = y < 2
X = X[binary_mask, :2]
y = y[binary_mask]

plt.scatter(X[y == 0, 0], X[y == 0, 1], color='red')
plt.scatter(X[y == 1, 0], X[y == 1, 1], color='blue')
plt.show()

from sklearn.model_selection import train_test_split

# Hold out 20% of the samples for testing; the fixed seed makes the split
# reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=666)

# Standardization
from sklearn.preprocessing import StandardScaler

# StandardScaler rescales each feature column independently, using the
# mean and (population) standard deviation learned from the training split:
#     X_scaled = (X - X_train.mean(axis=0)) / X_train.std(axis=0)
# The same fitted scaler is then applied to the test split, so no test-set
# statistics leak into preprocessing.
standardScaler = StandardScaler()
X_standard = standardScaler.fit_transform(X_train)
x_test_standard = standardScaler.transform(X_test)

from sklearn.svm import LinearSVC

# Fit on the standardized features: an SVM's margin is measured in feature
# units, so unscaled features distort it. The scaler above was fitted for
# exactly this purpose. A very large C makes the soft-margin penalty
# dominate (close to a hard-margin SVM).
svc = LinearSVC(C=1e9)
svc.fit(X_standard, y_train)

# Standardized training features have zero mean and unit variance, so a
# [-3, 3] window (about +/-3 standard deviations) covers the bulk of the data.
PLOT_AXIS = [-3, 3, -3, 3]

plot_decision_boundary(svc, axis=PLOT_AXIS)
plt.scatter(X_standard[y_train == 0, 0], X_standard[y_train == 0, 1])
plt.scatter(X_standard[y_train == 1, 0], X_standard[y_train == 1, 1])
plt.show()

# Evaluate on the held-out split, transformed with the scaler fitted on the
# training data. `score` takes the features and the true labels (not
# predictions) and returns the mean accuracy.
score = svc.score(x_test_standard, y_test)
print(f'分数：{score}')


def plot_svc_decision_boundary(model, axis):
    """Draw a linear classifier's decision regions plus its two margin lines.

    Parameters
    ----------
    model : fitted linear classifier
        Must provide ``predict`` and expose ``coef_`` / ``intercept_``; only
        ``coef_[0]`` (two weights, one per feature) and ``intercept_[0]``
        are read, so this targets a two-feature binary problem.
    axis : sequence of 4 numbers
        ``[x0_min, x0_max, x1_min, x1_max]`` window to draw.
    """
    # The shaded prediction regions are exactly what plot_decision_boundary draws.
    plot_decision_boundary(model, axis)

    w = model.coef_[0]
    b = model.intercept_[0]

    # The margins are the lines where the decision function equals +1 / -1:
    #   w0*x0 + w1*x1 + b = +/-1  =>  x1 = -w0/w1 * x0 - b/w1 +/- 1/w1
    if w[1] == 0:
        # Vertical margins (x1 drops out): x0 = (+/-1 - b) / w0.
        for offset in (1, -1):
            x0 = (offset - b) / w[0]
            if axis[0] <= x0 <= axis[1]:
                plt.plot([x0, x0], [axis[2], axis[3]], color='black')
        return

    plot_x = np.linspace(axis[0], axis[1], 200)
    for offset in (1, -1):
        margin_y = -w[0] / w[1] * plot_x - b / w[1] + offset / w[1]
        # Draw only the part of the line that lies inside the plot window.
        inside = (margin_y >= axis[2]) & (margin_y <= axis[3])
        plt.plot(plot_x[inside], margin_y[inside], color='black')


plot_svc_decision_boundary(svc, axis=PLOT_AXIS)
plt.scatter(X_standard[y_train == 0, 0], X_standard[y_train == 0, 1])
plt.scatter(X_standard[y_train == 1, 0], X_standard[y_train == 1, 1])
plt.show()